---
title: Transformations for 3D medical images
keywords: fastai
sidebar: home_sidebar
nb_path: "nbs/02_transforms.ipynb"
---
Using the @patch decorator attaches the transform function as a method of the Tensor class (and of all its subclasses, including TensorDicom3D and TensorMask3D).
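The mechanism behind @patch can be sketched in plain Python (a simplified stand-in for fastcore's actual implementation, with hypothetical classes for illustration): the decorator reads the type annotation of the function's first argument and attaches the function to that class, so all subclasses inherit it through normal attribute lookup.

```python
def patch(f):
    """Simplified stand-in for fastcore's @patch: attach f as a method
    of the class named in the annotation of its first argument."""
    cls = next(iter(f.__annotations__.values()))  # type of the first argument
    setattr(cls, f.__name__, f)
    return f

class Tensor3D:                      # hypothetical stand-in for the tensor base class
    def __init__(self, data): self.data = data

class TensorDicom3D(Tensor3D): pass  # subclass inherits patched methods

@patch
def double(self: Tensor3D):
    return [x * 2 for x in self.data]

print(TensorDicom3D([1, 2, 3]).double())  # [2, 4, 6]
```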
Resizer = Resize3D((10,50,50))
original = TensorDicom3D.create('/media/ScaleOut/prostata/data/dcm/A0042197734/T2/DICOM')
mask = TensorMask3D.create('/media/ScaleOut/prostata/data/dcm/A0042197734/T2/Annotation/cropped_mask.nii.gz')
original.show()
mask.show()
im = Resizer(original, split_idx = 0)
ma = Resizer(mask, split_idx = 0)
im.show()
ma.show()
In medical images, the left and right side often cannot be differentiated from each other (e.g. scans of the head, hand, knee, ...). Therefore the image orientation is stored in the image header, enabling the viewer system to display the images accurately. For deep learning, only the pixel array is extracted, so the header information is lost. When displaying only the pixel array, the images might already appear flipped or in inverted slice order. Vertical/horizontal flipping as well as flipping along the z axis can therefore be used for data augmentation.
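Conceptually, each of these flips is just a reversal along one axis of the (D, H, W) array. A minimal pure-Python sketch on a tiny nested-list volume (the real transforms operate on tensors, and the axis-to-name mapping below is an assumption):

```python
# A tiny 2x2x2 volume as nested lists, indexed [z][y][x]
vol = [[[0, 1], [2, 3]],
       [[4, 5], [6, 7]]]

flip_cc = vol[::-1]                                # reverse slice order (z axis)
flip_ap = [s[::-1] for s in vol]                   # reverse rows in every slice
flip_ll = [[row[::-1] for row in s] for s in vol]  # reverse columns in every row

print(flip_cc)  # [[[4, 5], [6, 7]], [[0, 1], [2, 3]]]
```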
show_images_3d(torch.stack((im, flip_ll_3d(im), flip_ap_3d(im), flip_cc_3d(im))))
flipper = RandomFlip3D()
flipper(im)
show_images_3d(torch.stack((im, flipper(im, split_idx = 0), flipper(im, split_idx = 0), flipper(im, split_idx = 0))))
Medical images should show no rotation; however, with removal of the image file header, the pixel array might appear rotated when displayed and thus be presented to the model rotated. Furthermore, in some images the patients might be rotated to some degree. Thus rotations of 90° and 180°, as well as intermediate angles, should be implemented.
rotator = RandomRotate3D()
show_images_3d(torch.stack((im, rotate_90_3d(im), rotate_180_3d(im), rotate_270_3d(im),
rotator(im, split_idx = 0), rotator(im, split_idx = 0), rotator(im, split_idx = 0))))
rotator_by = RandomRotate3DBy()
show_images_3d(torch.stack((im, rotate_3d_by(im, angle = 15, axes = [1,2]), rotator_by(im, split_idx = 0))))
im2 = resize_3d(readdcm_3d('/media/ScaleOut/prostata/data/dcm/A0042197734/T2/DICOM', return_normalized = True), (25, 10, 50))
show_images_3d(torch.stack((im2, rotate_3d_by(im2, angle =10, axes = [0,2]))), axis =1, nrow = 5)
Rotating by 90 (or 180 and 270) degrees should not be done via rotate_3d_by but via rotate_90_3d, which is approximately 28 times faster.
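The speed difference is plausible because a 90° rotation is a pure index permutation (reverse rows, then transpose), whereas an arbitrary-angle rotation requires interpolating every voxel. A pure-Python sketch of the 90° case on a single slice:

```python
def rot90_slice(s):
    """90-degree in-plane rotation = reverse rows, then transpose.
    A pure index permutation, so no interpolation is needed."""
    return [list(r) for r in zip(*s[::-1])]

s = [[0, 1],
     [2, 3]]
print(rot90_slice(s))  # [[2, 0], [3, 1]]
```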
As the 3D array can be flipped along three axes but should only be rotated around the z axis, this does not form a complete dihedral group. Still, multiple combinations of flipping and rotating should be implemented:
I am not sure if this is complete...
dihedral = RandomDihedral3D()
show_images_3d(torch.stack((im, dihedral(im, split_idx = 0), dihedral(im, split_idx = 0),
dihedral(im, split_idx = 0),dihedral(im, split_idx = 0),
dihedral(im, split_idx = 0))))
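Whether this set of combinations is complete can be checked by brute force: enumerating every composition of in-plane 90° rotations and the three axis flips on a volume with all-distinct voxel values yields 16 distinct transforms (the in-plane dihedral group of order 8 times the independent z-flip). A pure-Python sketch (the helper functions below are illustrative stand-ins, not the library's API):

```python
from itertools import product

# A 2x2x2 volume (D, H, W) with all-distinct voxel values, as nested tuples
vol = (((0, 1), (2, 3)), ((4, 5), (6, 7)))

def flip_z(v):   # reverse slice order
    return v[::-1]

def flip_y(v):   # reverse rows in every slice
    return tuple(s[::-1] for s in v)

def flip_x(v):   # reverse columns in every row
    return tuple(tuple(r[::-1] for r in s) for s in v)

def rot90(v):    # rotate every slice by 90 degrees in-plane
    return tuple(tuple(zip(*s[::-1])) for s in v)

results = set()
for k, fz, fy, fx in product(range(4), (0, 1), (0, 1), (0, 1)):
    t = vol
    for _ in range(k):
        t = rot90(t)
    if fz: t = flip_z(t)
    if fy: t = flip_y(t)
    if fx: t = flip_x(t)
    results.add(t)

print(len(results))  # 16 distinct transforms (32 words, 2-to-1: flip_x ∘ flip_y = rot180)
```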
A reasonable approach for 3D medical images is presizing to a uniform but too large volume and subsequent random cropping to the target dimension. As most areas of interest are located centrally in the image/volume, some cropping can always be applied.
Random cropping should also be applied after any rotation that is not a multiple of 90 degrees, so that empty margins are cropped.
Cropper = RandomCrop3D((0,10,10), (0,5,5), False)
show_images_3d(torch.stack((Cropper(im, split_idx = 0), Cropper(im, split_idx = 0),
Cropper(im, split_idx = 0), Cropper(im, split_idx = 0))))
Other cropping methods, with padding, squishing or large magnifications, might not be appropriate for medical images, since often only small areas of the image are of importance (e.g. a tumor) and could be removed by aggressive cropping. Cropping should therefore only be applied to the image margins.
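The idea of margin-only cropping can be sketched as follows (an illustrative helper, not the RandomCrop3D implementation): the crop offset along each axis is drawn only from within a fixed margin at the border, so the central region always survives.

```python
import random

def random_margin_crop(shape, final, margin, seed=None):
    """Pick a random crop window of size `final` whose offset along each
    axis stays within `margin` voxels of the original border, so only
    the margins are ever removed."""
    rng = random.Random(seed)
    starts = [rng.randint(0, m) for m in margin]
    return [slice(s, s + f) for s, f in zip(starts, final)]

# 30x60x60 volume cropped to 30x50x50, offsets drawn from a 10-voxel margin
sl = random_margin_crop((30, 60, 60), final=(30, 50, 50), margin=(0, 10, 10), seed=0)
print(sl)
```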
As cropping and resizing are good preprocessing operations, they can be merged into one class for easier access.
warper = RandomWarp3D(p=1)
show_images_3d(torch.stack((warper(im, split_idx = 0), warper(im, split_idx = 0), warper(im, split_idx = 0), warper(im, split_idx = 0), warper(im, split_idx = 0))))
noise_adder = RandomNoise3D(p=1)
show_images_3d(torch.stack((noise_adder(im, split_idx = 0), noise_adder(im, split_idx = 0),
noise_adder(im, split_idx = 0), noise_adder(im, split_idx = 0), noise_adder(im, split_idx = 0))))
lighting = RandomBrightness3D()
show_images_3d(torch.stack((im, lighting(im, split_idx = 0), lighting(im, split_idx = 0),
lighting(im, split_idx = 0), lighting(im, split_idx = 0),
lighting(im, split_idx = 0), lighting(im, split_idx = 0))))
contrast = RandomContrast3D()
show_images_3d(torch.stack((im, contrast(im, split_idx = 0), contrast(im, split_idx = 0),
contrast(im, split_idx = 0), contrast(im, split_idx = 0),
contrast(im, split_idx = 0), contrast(im, split_idx = 0))))
A good workflow is to apply a random crop to all images after each transformation. For this, the images should be presized to a size just a few pixels larger than desired, then transformed, then cropped to the final size. With this approach, empty space, which e.g. appears after RandomRotate3DBy, is cropped away and does not influence the accuracy of the model. One only has to be careful that the region of interest, e.g. the prostate, remains inside every cropped image.
im = readdcm_3d('/media/ScaleOut/prostata/data/dcm/A0042197734/T2/DICOM', return_normalized = True)
im = resize_3d(im, (30, 250, 250)) # presizing the images
Cropper = RandomCrop3D((5,50,50), (1,5,5))
tfms = [RandomBrightness3D(), RandomContrast3D(), RandomWarp3D(), RandomDihedral3D(), RandomNoise3D(), RandomRotate3DBy()]
tfms = [Pipeline([RandomBrightness3D, Cropper], split_idx = 0),
Pipeline([RandomContrast3D, Cropper], split_idx = 0),
Pipeline([RandomWarp3D, Cropper], split_idx = 0),
Pipeline([RandomDihedral3D, Cropper], split_idx = 0),
Pipeline([RandomNoise3D, Cropper], split_idx = 0),
Pipeline([RandomRotate3DBy, Cropper], split_idx = 0)]
comp = setup_aug_tfms(tfms)
ims = [t(im) for t in tfms]
show_images_3d(torch.stack(ims))
MakeColor = PseudoColor()
im.shape, MakeColor(im, split_idx = 0).shape
aug_transforms_3d()
# Is now implemented as a callback
def _make_binary(t, set_to_one):
    "Sets all values but one to zero; the remaining value is set to one."
    return (t == set_to_one).float().to(t.device)
@patch
def to_one_hot(m:(Tensor, TensorMask3D), num_features:int):
    """
    Takes a Tensor and returns a one hot encoded version,
    where every channel corresponds to a single one hot encoded value.

    Args:
        m: a Tensor or TensorMask3D in the format B*C*D*H*W, where C should be 1
        num_features: number of foreground classes to be one hot encoded

    Returns:
        A one hot encoded tensor with num_features + 1 channels (including the background class 0)
    """
    m = m.squeeze(1).long()  # remove the solitary color channel (if present) and cast to int64
    one_hot = [_make_binary(m, set_to_one=i) for i in range(num_features + 1)]  # classes 0..num_features
    return torch.stack(one_hot, 1).to(m.device)
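The per-voxel logic of this one-hot encoding is easy to illustrate without torch (a pure-Python sketch, not the library function): a mask with values 0..num_features becomes num_features + 1 binary channels, one per value.

```python
def one_hot(mask, num_features):
    """For a flat list of mask values in 0..num_features, return one
    binary channel per value (channel 0 is the background class)."""
    return [[1.0 if v == c else 0.0 for v in mask] for c in range(num_features + 1)]

print(one_hot([0, 2, 1, 2], num_features=2))
# [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
```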
class MaskOneHot(RandTransform):
    split_idx, p = 1, 1
    def __init__(self, num_features, p=1):
        super().__init__(p=p)
        self.num_features = num_features
    def __call__(self, b, split_idx=1, **kwargs):
        "Overridden __call__ to enforce that the transform is always applied to every dataset."
        return super().__call__(b, split_idx=split_idx, **kwargs)
    def encodes(self, x:TensorMask3D):
        return x.to_one_hot(self.num_features)